/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

/*
 * Debug helper: prints the label `what` followed by the first 32 elements
 * of `key`, each formatted with "%d, ", and a trailing newline.
 *
 *   what - C string used as the line label.
 *   key  - any expression that can be indexed with [0..31].
 *
 * The loop counter has a macro-private name so that a `key` expression
 * which itself mentions a caller variable called `i` (for example
 * `keys[i]`) is not captured by the macro's own counter.
 */
#define DBG_KEY(what, key) do { \
    printf("%s: ", (what)); \
    for(size_t _dbg_key_i = 0; _dbg_key_i < 32; ++_dbg_key_i) { \
        printf("%d, ", (key)[_dbg_key_i]); \
    } \
    printf("\n"); \
} while(0)

/*
 * Debug helper: prints the label `what` followed by the contents of the
 * field element `fe` in hexadecimal, and a trailing newline.
 *
 * When __CUDA_ARCH__ is defined, `fe` is copied into a temporary, the
 * temporary is passed through fe_reduce(), and its FE_SIZE words are
 * printed. The caller's value is not modified.
 *
 * Otherwise `fe` is serialized with fe_to_bytes() into a buffer of eight
 * 32-bit words, which are printed one by one.
 * NOTE(review): presumably fe_to_bytes() writes exactly 32 bytes - confirm.
 *
 * Temporaries and loop counters use macro-private names and the arguments
 * are parenthesized, so the expansion cannot capture or mis-parse the
 * caller's expressions.
 */
#ifdef __CUDA_ARCH__
#define DBG_FE(what, fe) do { \
    FieldElement _dbg_fe_tmp; \
    fe_copy(_dbg_fe_tmp, (fe)); \
    fe_reduce(_dbg_fe_tmp); \
    printf("%s: ", (what)); \
    for(size_t _dbg_fe_i = 0; _dbg_fe_i < FE_SIZE; ++_dbg_fe_i) \
        printf("0x%x, ", _dbg_fe_tmp[_dbg_fe_i]); \
    printf("\n"); \
} while(0)
#else
#define DBG_FE(what, fe) do { \
    uint32_t _dbg_fe_words[8]; \
    fe_to_bytes((uint8_t*)_dbg_fe_words, (fe)); \
    printf("%s: ", (what)); \
    for(size_t _dbg_fe_i = 0; _dbg_fe_i < 8; ++_dbg_fe_i) \
        printf("0x%x, ", _dbg_fe_words[_dbg_fe_i]); \
    printf("\n"); \
} while(0)
#endif

/*
 * precompute_addends_kernel
 *
 * For every work-item index gid < *midstate_iterations:
 *   1. reads the 32-byte scalar scalars[gid] as eight 32-bit words;
 *   2. runs scalar_mul_folds() on it with the locally cached fold table
 *      (presumably scalar * base point - confirm against scalar_mul_folds);
 *   3. divides the resulting point's x and y by z (via fe_invert);
 *   4. stores y + x, y - x and x * y * ed25519_d2 into slot gid
 *      (FE_SIZE words per slot) of addends_ypx, addends_ymx and addends_xy.
 *
 * Parameters:
 *   addends_ypx, addends_ymx, addends_xy - output arrays, one FE_SIZE-word
 *       slot per work-item.
 *   scalars                   - input scalars, indexed by gid.
 *   basepoint_mul_fold_table_ - source table copied into local storage by
 *       load_fold_table().
 *   midstate_iterations       - pointer to the number of valid work-items.
 *
 * Work-items with gid >= *midstate_iterations produce no output, but they
 * still take part in the table load and the barrier.
 */
KERNEL_SPEC(precompute_addends_kernel, BLOCK_SIZE)
(
    GLOBAL uint32_t* addends_ypx,
    GLOBAL uint32_t* addends_ymx,
    GLOBAL uint32_t* addends_xy,
    GLOBAL const Bn256* __restrict__ scalars,
    GLOBAL const struct AffineNielsPoint* __restrict__ basepoint_mul_fold_table_,
    GLOBAL const size_t *__restrict__ midstate_iterations
) {
    const size_t gid = global_id();

    // Every work-item of the group must reach the table load and the
    // barrier, including the ones past the end of the input. Returning
    // before SYNCTHREADS() makes the barrier divergent in the last
    // (partially filled) group.
    // NOTE(review): load_fold_table is presumably a cooperative load that
    // relies on all work-items of the group participating - confirm.
    LOCAL struct AffineNielsPoint basepoint_mul_fold_table[256];
    load_fold_table(basepoint_mul_fold_table, basepoint_mul_fold_table_);
    SYNCTHREADS();

    // Only now is it safe to drop the out-of-range work-items.
    if(gid >= *midstate_iterations)
        return;

    struct EdwardsPoint p;
    uint32_t scalar[8];

    // TODO : get rid of the cast
#ifndef __CUDA_ARCH__
    // Copies the scalar one 32-bit word at a time through a union-typed
    // pointer instead of calling memcpy on this path.
    union scopy_t {
        uint32_t u32;
        uint8_t u8[4];
    };

    for(int i = 0; i < 8; ++i) {
        scalar[i] = ((GLOBAL const union scopy_t*)(scalars[gid] + 4 * i))->u32;
    }
#else
    memcpy(scalar, scalars[gid], 32);
#endif

    scalar_mul_folds(&p, scalar, basepoint_mul_fold_table);

    // Divide x and y by z: one inversion, two multiplications.
    struct AffineNielsPoint a;
    FieldElement x, y, z;
    fe_invert(z, p.z);

    fe_mul(x, p.x, z);
    fe_mul(y, p.y, z);

    // Build the three stored values: y + x, y - x and x * y * ed25519_d2.
    fe_add(a.y_plus_x, y, x);
    fe_sub(a.y_minus_x, y, x);
    fe_mul(a.xy2d, x, y);
    fe_mul(a.xy2d, a.xy2d, ed25519_d2);

    // Each output array holds one FE_SIZE-word slot per work-item.
    fe_copy(addends_ypx + gid * FE_SIZE, a.y_plus_x);
    fe_copy(addends_ymx + gid * FE_SIZE, a.y_minus_x);
    fe_copy(addends_xy + gid * FE_SIZE, a.xy2d);
}